{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 导入包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.nn import MessagePassing\n",
    "from torch_geometric.utils import add_self_loops, degree"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PyG构建数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2708/2708 [00:00<00:00, 2799.36it/s]\n"
     ]
    }
   ],
   "source": [
    "'''\n",
    "Build the Cora node-classification graph and wrap it in torch_geometric.data.Data.\n",
    "Fields set on the Data object:\n",
    "    x : torch.Tensor, node feature matrix, shape [num_nodes, num_node_features]\n",
    "    edge_index : LongTensor, graph connectivity, shape [2, num_edges]\n",
    "    y : LongTensor, class index of every node, shape [num_nodes]\n",
    "edge_attr, pos, norm and face are not needed here and are left unset.\n",
    "'''\n",
    "content_path = \"./cora/cora.content\"\n",
    "cite_path = \"./cora/cora.cites\"\n",
    "\n",
    "# Read the two raw text files\n",
    "with open(content_path, \"r\") as fp:\n",
    "    contents = fp.readlines()\n",
    "with open(cite_path, \"r\") as fp:\n",
    "    cites = fp.readlines()\n",
    "\n",
    "# Edge list: every line holds two tab-separated paper ids\n",
    "cites = list(map(lambda x: x.strip().split(\"\\t\"), cites))\n",
    "\n",
    "# Parse every content line: <paper id> <feature values ...> <label name>\n",
    "paper_list, feat_list, label_list = [], [], []\n",
    "for line in tqdm(contents):\n",
    "    tag, *feat, label = line.strip().split(\"\\t\")\n",
    "    paper_list.append(tag)\n",
    "    feat_list.append(np.array([int(v) for v in feat]))\n",
    "    label_list.append(label)\n",
    "# Paper id -> node index\n",
    "paper_dict = {paper: idx for idx, paper in enumerate(paper_list)}\n",
    "# Label name -> class index. The names are sorted so the mapping is the same on\n",
    "# every run; iterating a bare set of strings follows hash order, which differs\n",
    "# between interpreter sessions.\n",
    "label_dict = {name: idx for idx, name in enumerate(sorted(set(label_list)))}\n",
    "# edge_index: translate paper ids to node indices, then append the reversed edges\n",
    "cites = np.array([[paper_dict[i[0]], paper_dict[i[1]]] for i in cites], \n",
    "                 np.int64).T                                 # (2, edge)\n",
    "cites = np.concatenate((cites, cites[::-1, :]), axis=1)      # (2, 2*edge), i.e. (2, E)\n",
    "# NOTE(review): a pair listed in both directions in the cites file would appear\n",
    "# twice after this step - confirm whether duplicates should be removed.\n",
    "# Labels as class indices; int64 is requested explicitly because F.nll_loss needs\n",
    "# a LongTensor target and the NumPy default integer is platform dependent.\n",
    "y = np.array([label_dict[i] for i in label_list], dtype=np.int64)\n",
    "# Convert everything to tensors\n",
    "x = torch.from_numpy(np.array(feat_list, dtype=np.float32))  # [N, Feat_Dim]\n",
    "edge_index = torch.from_numpy(cites)                         # [2, E]\n",
    "y = torch.from_numpy(y)                                      # [N, ]\n",
    "\n",
    "# Assemble the Data object\n",
    "data = Data(x=x,\n",
    "            edge_index=edge_index,\n",
    "            y=y)\n",
    "# Split the nodes. Boolean masks are used because indexing with uint8 masks is\n",
    "# deprecated in PyTorch.\n",
    "data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)\n",
    "data.train_mask[:data.num_nodes - 1000] = True               # all but the last 1000 nodes train\n",
    "data.val_mask = None                                         # no validation split\n",
    "data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)\n",
    "data.test_mask[data.num_nodes - 500:] = True                 # last 500 nodes test\n",
    "data.num_classes = len(label_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 打印数据信息"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Data(edge_index=[2, 10858], num_classes=[1], test_mask=[2708], train_mask=[2708], x=[2708, 1433], y=[2708])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the assembled Data object; its repr lists every stored field with its shape.\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "********************Data Info********************\n",
      "==> Is undirected graph : True\n",
      "==> Number of edges : 10858/2=5429\n",
      "==> Number of nodes : 2708\n",
      "==> Node feature dim : 1433\n",
      "==> Number of training nodes : 1708\n",
      "==> Number of testing nodes : 500\n",
      "==> Number of classes : 7\n"
     ]
    }
   ],
   "source": [
    "# Summary of the graph: one (template, values) pair per printed line.\n",
    "banner = \"*\"*20\n",
    "info_rows = [\n",
    "    (\"==> Is undirected graph : {}\", (data.is_undirected(),)),\n",
    "    (\"==> Number of edges : {}/2={}\", (data.num_edges, int(data.num_edges/2))),\n",
    "    (\"==> Number of nodes : {}\", (data.num_nodes,)),\n",
    "    (\"==> Node feature dim : {}\", (data.num_node_features,)),\n",
    "    (\"==> Number of training nodes : {}\", (data.train_mask.sum().item(),)),\n",
    "    (\"==> Number of testing nodes : {}\", (data.test_mask.sum().item(),)),\n",
    "    (\"==> Number of classes : {}\", (data.num_classes,)),\n",
    "]\n",
    "print(\"{}Data Info{}\".format(banner, banner))\n",
    "for template, values in info_rows:\n",
    "    print(template.format(*values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "********************Train Data anf Test Data Info********************\n",
      "==> Label list : \n",
      "    (0, 'Genetic_Algorithms')\n",
      "    (1, 'Neural_Networks')\n",
      "    (2, 'Rule_Learning')\n",
      "    (3, 'Reinforcement_Learning')\n",
      "    (4, 'Probabilistic_Methods')\n",
      "    (5, 'Case_Based')\n",
      "    (6, 'Theory')\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Text(0.5, 1.0, 'Test Data Statics')"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEWCAYAAAB8LwAVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XvcVGW5//HPl4OAeEAFjTgIInlIU4jSPFSK9bM8lZblLkFTyb3bVltNsV/qdttOt5mW7l3JVkPMA2j1ktSfbQPN3AmK53OAKDxBgqGYKSly/f5Y9wPDw3oeBmbWzDzPfN+v17xm1r3W3Ot64Jq5Zp3upYjAzMysrW71DsDMzBqTC4SZmeVygTAzs1wuEGZmlssFwszMcrlAmJlZLheIGpDUXdIbkobWO5ZGJel/JH2p3nGYVYOkmZK+UO84KuUCkSN9mbc+Vkt6q2R6o7/EIuLdiNgiIhZuQiw7S4qS9f9Z0q8ljd2IPk6WdO/Grrvk/ZL0HUkvphhaJN1QMv9+SSdsRH/flTS5tC0iPhkRN7TzFquyaud4Sb+zJH25g/m75uTzdEkHbcQ6TpX02wpilKTzS/J5kaQp5f4NOf1dLOnq0raIODgipm5qjI3CBSJH+jLfIiK2ABYCR5S0rfclJqlHrWICRgEzgekbk8QV+grwReDgFMOHgHtrtG4rwMbmeJW92yaf7wNul/TFgtfbagJwDHBQimGfFIO1FRF+dPAAXgQOadP2XWAqcBPwV+AE4CPALOA1YAlwBdAzLd8DCGBYmv55mv//0vsfAIa3s/6ds/+m9donAosBpenvAC+k/p4GjkztewIrgXeBN4BXUvuRwGNp+YXAuR38G/wUuLSdef+R+l6Z+v9hav9PoAV4HXgI2C+1Hw68DbyTln84td8PnFDS71eB51J8TwF7pfZvp7/79TT/4/XOkc7+aCfHuwPnppx6BbgB6Jfm9QVuBpanfJ8NbAP8oE0u/CBnXbsCq3LavwMsKpk+D1hQ8v9/WGoflfpfldbx59T+WeDxlBcvAd/u4O+9Gri4nXm5fwPwk5J8fhDYN7V/pk0+P5jaZwFfLun3n0ry+Ulgz9R+Ltn3xevAs8CB9c6Hdf496h1Aoz/a+fB8NyXFEWRbYX3IflXvQ1YMdgL+CPxzWj6vQLwCjAF6khWbn7ez/vYKxPtSnyPT9LHAwBTPP6Rk3SHNOxm4t837Dwb2SMvvleI5vJ0YTgD+ApwJfBDo3mb+Ol/uqe14YNv0t58N/AnoVfLvN7m9PoDjgEVpXUp/6xDg/enD/5603HBgp3rnSGd/tJPjE4HfA+8FegOTgZ+led8Abk153yPlft80b50vxpx1tVcgdk/5PDxNf6Ekn48n+2Ltn+adCvy2zfvHpvzoBowmK16HthPDycAy4PS0bNt8Xu9vAMaRFcGewP9N+dn6A/Bi4Or2+kjxv0RW3ATsAgxOn7sXgB1S+06080OxXg/vYtp090fEryNidUS8FREPRcTsiFgVES8Ak4CPdfD+WyNiTkS8Q/brbO+NXP/i9LwtQERMi4glKZ4byT70Y9p7c0TMjIin0vKPk/0izI03IiYD3wQ+RbYpvlTSmR0FFxHXR8TyiFgFXAJsRVbsynEy2S+8hyPzx4hYRParsTfwfkk9ImJB+re26vsqMDEiFkfESuAC4AuSRPZreQAwIuX7QxHxtwrX1zafp5bk8/VkPzA+2N6bI2JGRDydln8EmEb7n79ryH7sHEH2w+RlSf/SUXARMSUiXk2f1+8B25F9oZfjZOB7EfFoyufnI6KFLJ/7kBXH7hHxQkQsKLPPmnCB2HSLSifSwbc70kG314F/A/p38P4/l7x+E9hiI9c/KD0vT+s/QdLjkl6T9BrZL7V21y/pI5LulbRM0gqyJG53+fSFPxboB3wNuKijA+WSzpL0XOr7VbLdEh39e5QaAszPieF54Ayyf9ul
km6S9J4y+7QypSIwBLizJJ8eJfu+2I7sC/Z3wK3phIXvSepe4Wrb5vNJkp4oWf/OdJzP+0v6XUk+n9De8ulL+rqIOIgsn78OXCKp3R90ks6R9HxJPvfuKJ422svnp8m21P6dLJ9vkLRDmX3WhAvEpms7DO5VZPtKd46Ircj2oarA9X+WrMjMk7QT2T7SfwS2i4h+ZPs7W9efN2TvzcAvgCERsTXZftkNxhsR70TEzWTHOfbI6z+dkXI62YHAfmSb5m9sIJ5Si4AR7az/5xGxP9nupe7ARRuK2TZOZPtF/kR2UkK/kkfviHglIv4eEedFxK7AR4HPk53EABv+v23PZ4GWiFgg6X3AlWQHk7dN+TyPjvNnGtmu2tZ8nkx5+fx22uJ+nvbz+RPAaSnGfmRbOW9tIJ5SHeXzdRGxH9nWSG+y3a8NwwWierYEVgB/k7Qb2SZ61UnaQdLXyQ7qnZ0+zFuQJemybBGdTLYF0eplYLCknm3iXR4RKyXty9oPeN46vyLp05K2lNRN0mFk+1EfLOm/dHN7S7LN51fI9tn+K9kWRGk8w9Iv1TxXA2dJGpVOSRwpaYik3SQdJKkX2Qf0LbIDilZ9PwUuljQEQNL2ko5Irw+RtLukbmQHV1ex9v+hbS50SNJ70u6dc8h+TUOWz6vJ8rmbpFNZd/fky8CQ1nxOebQF8JeUz/uRFa321nmypEMlbZHy+cjUf0f5/E6KZzOyLdjebeIZvoF8nihpr5TP75M0OP0bfqyR89kFonrOAMaTHUy7iuzXTNWk87XfAJ4A/g9wdERMAYiIJ8jOinqQ7IyIXcnOLGl1NzCXbF9r666tfyTbTfRXsjODpnWw+tdJZ5mQbV5/D5gQEQ+k+T8Ejku7Ay4D7gR+m9b5Ynr/kpL+ppJ90JZLepA2IuImsrOjpqb3/pJsK6QX2fGMV8i2nrZJcVn1XUL2fzgz5cgfyA7oQrY76DbWnmF0J2vz53JgnKRXJV3STt+tF47+jezMo7HAUZFOr03HEH4KzCHLm+Hpdau7yPJqqaSW9CPpVODSFOtZwC0d/G1/Bc4nOyvpVeBC4KSIeKidv+HXZMfe5rP2rK5lJf3dDGxOls9/aLuydAzlMrID+6+n535kxx9+kPpbQlbkzusg7pprPUXSzMxsHd6CMDOzXIUWCEn9JN2azmZ5Np05s62kuyXNTc/bpGUl6QpJ89LZC6M31L9ZvTi3rRkUvQXxI+CudLbDXmRXCk4EZkTESGAGaw9MfQoYmR4TyM7KMWtUzm3r8go7BiFpK7IDUDtFyUokPU82PMISSQPJrvDdRdJV6fVNbZcrJECzTeTctmZR5CBzO5Ed6f+ZpL2Ah8ku0d+h9YORPkjbp+UHse7FZy2pbZ0PkaQJZL/C6Nu37wd33bX0bE6z6nn44YdfiYgBObOc29apdZDb6yiyQPQgOy3utIiYLelHrN3kzpN3DvF6mzcRMYlsGAvGjBkTc+bMWe9NZtUg6aV2Zjm3rVPrILfXUeQxiBayKyNbz8e/lexD9XLa/CY9Ly1ZfkjJ+wezdnwWs0bi3LamUFiBiIg/A4sk7ZKaxgLPANPJLigjPd+WXk8nuzhFyq7sXeF9tNaInNvWLIq+0c1pwA2SNiO7AvFEsqI0TdJJZPchaL0k/k7g02RjrryZljVrVM5t6/IKLRAR8Rj5Q06vNwpoOhvka0XGY1Ytzm1rBr6S2szMcrlAmJlZLhcIMzPL5QJhZma5XCDMzCyXC4SZmeVygTAzs1wuEGZmlssFwszMcrlAmJlZLhcIMzPL5QJhZma5XCDMzCyXC4SZmeVygTAzs1wuEGZmlssFwszMcrlAmJlZLhcIMzPL5QJhZma5XCDMzCyXC4SZmeXqUe8AOothE++ouI8XLz6sCpGYmdWGtyDMzCyXC4SZmeVygTAzs1wuEGZmlqvQAiHpRUlPSnpM0pzUtq2kuyXNTc/bpHZJukLSPElPSBpdZGxmlXBuWzOoxRbEQRGx
d0SMSdMTgRkRMRKYkaYBPgWMTI8JwE9qEJtZJZzb1qXVYxfTUcB16fV1wGdK2qdEZhbQT9LAOsRntqmc29alFF0gAvgfSQ9LmpDadoiIJQDpefvUPghYVPLeltRm1oic29blFX2h3P4RsVjS9sDdkp7rYFnltMV6C2UfxgkAQ4cOrU6UZhvPuW1dXqFbEBGxOD0vBX4FfBh4uXXzOj0vTYu3AENK3j4YWJzT56SIGBMRYwYMGFBk+Gbtcm5bMyisQEjqK2nL1tfAJ4GngOnA+LTYeOC29Ho6MC6d8bEvsKJ1c92skTi3rVkUuYtpB+BXklrXc2NE3CXpIWCapJOAhcDn0/J3Ap8G5gFvAicWGJtZJZzb1hQKKxAR8QKwV077X4CxOe0BfK2oeMyqxbltzcJXUpuZWS4XCDMzy+UCYWZmuVwgzMwslwuEmZnlcoEwM7NcLhBmZpbLBcLMzHK5QJiZWS4XCDMzy+UCYWZmuVwgzMwslwuEmZnlcoEwM7NcLhBmZpbLBcLMzHK5QJiZWS4XCDMzy+UCYWZmuVwgzMwslwuEmZnlcoEwM7NcLhBmZpbLBcLMzHK5QJiZWS4XCDMzy+UCYWZmuVwgzMwsV+EFQlJ3SY9Kuj1ND5c0W9JcSVMlbZbae6XpeWn+sKJjM9tUzmtrBrXYgvgG8GzJ9H8Al0fESOBV4KTUfhLwakTsDFyeljNrVM5r6/IKLRCSBgOHAVenaQEHA7emRa4DPpNeH5WmSfPHpuXNGorz2ppF0VsQPwTOAlan6e2A1yJiVZpuAQal14OARQBp/oq0/DokTZA0R9KcZcuWFRm7WXuqntfg3LbGU1iBkHQ4sDQiHi5tzlk0ypi3tiFiUkSMiYgxAwYMqEKkZuUrKq/BuW2Np0eBfe8PHCnp00BvYCuyX179JPVIv6YGA4vT8i3AEKBFUg9ga2B5gfGZbQrntTWNwrYgIuKciBgcEcOALwIzI+JLwD3A59Ji44Hb0uvpaZo0f2ZE5P7SMqsX57U1k3pcB3E2cLqkeWT7Yq9J7dcA26X204GJdYjNbFM5r63LKXIX0xoRcS9wb3r9AvDhnGVWAp+vRTxm1eC8tq6u7AIhaavS5SPC+1GtS3j99ddZtWrVmultt922jtGYNY4NFghJXwX+DXiLtWdfBLBTgXGZFe6qq67ivPPOo0+fPrRemiCJF154oc6RmTWGcrYgzgTeHxGvFB2MWS1deumlPP300/Tv37/eoZg1pHIOUs8H3iw6ELNaGzFiBJtvvnm9wzBrWOVsQZwD/EHSbODvrY0R8fXCojKrgYsuuoj99tuPffbZh169eq1pv+KKK+oYlVnjKKdAXAXMBJ5k7dACZp3eV7/6VQ4++GD23HNPunXzyPdmbZVTIFZFxOmFR2JWYz169OCyyy6rdxhmDaucn033pEHEBkratvVReGRmBTvooIOYNGkSS5YsYfny5WseZpYpZwviH9LzOSVtPs3VOr0bb7wRyI5FtPJprmZrbbBARMTwWgRiVmsLFiyodwhmDa2cC+XG5bVHxJTqh2NWO1Om5KfwuHG5KW/WdMrZxfShkte9gbHAI4ALhHVqDz300JrXK1euZMaMGYwePdoFwiwpZxfTaaXTkrYGri8sIrMaufLKK9eZXrFiBccff3ydojFrPJty8vebwMhqB2JWb5tvvjlz586tdxhmDaOcYxC/Zu0gfd2A3YFpRQZlVgtHHHHEmkH6Vq9ezTPPPMOxxx5b56jMGkc5xyAuLXm9CngpIloKisesZs4888w1r3v06MGOO+7I4MGD6xiRWWMp5xjE72oRiFmtfexjH6t3CGYNrd0CIWkBa3cttRURMaKYkMyKNXz48DW7ltqSxPz582sckVlj6mgLYkyb6W7AsWT3h3i0sIjMCjZnzpx1plevXs20adO49NJLGTVqVJ2iMms87RaIiPgLgKRuwPHAt4DHgMMi4pnahGdWfdtttx2QFYbrr7+e73//++y9997c
cccd7L777nWOzqxxdLSLqSfwFeBfgPuBoyLC297W6b3zzjtce+21XH755RxwwAHcdtttjBjhPaZmbXW0i2kB2VlLPwQWAntJ2qt1ZkT8suDYzAoxfPhwevTowTe/+U2GDh3K448/zuOPP75m/tFHH13H6MwaR0cF4rdkB6n3So9SAbhAWKd0yCGHIGm9wgDZQWoXCLNMR8cgTqhhHGY1M3ny5HqHYNYp+D6LZmaWywXCzMxyuUCYmVmuDRYISZtLOlfSf6fpkZIOL+N9vSU9KOlxSU9LuiC1D5c0W9JcSVMlbZbae6XpeWn+sMr+NLOOvfnmm1x44YWccsopAMydO5fbb799g+9zbluzKGcL4mfA34GPpOkW4LtlvO/vwMERsRewN3CopH2B/wAuj4iRwKvASWn5k4BXI2Jn4PK0nFlhTjzxRHr16sUDDzwAwODBg/nOd75Tzlud29YUyikQIyLiEuAdgIh4C8gfyKZEZN5Ikz3TI4CDgVtT+3XAZ9Lro9I0af5YtTdgjlkVzJ8/n7POOouePXsC0KdPHyLaG35sLee2NYtyCsTbkvqQBu6TNILsF9QGSeou6TFgKXA3MB94LSJWpUVagEHp9SBgEUCavwLYLqfPCZLmSJqzbNmycsIwy7XZZpvx1ltvrRm4b/78+fTq1aus9zq3rRmUUyDOB+4Chki6AZgBnFVO5xHxbkTsDQwGPgzslrdYes77RbXez7mImBQRYyJizIABA8oJwyzXBRdcwKGHHsqiRYv40pe+xNixY7nkkkvKeq9z25pBOfeDuFvSI8C+ZIn+jYh4ZWNWEhGvSbo39dFPUo/0S2owsDgt1gIMAVok9QC2BpZvzHrMNsYnPvEJRo8ezaxZs4gIfvSjH9G/f/+N6sO5bV1ZR4P1jW7TtCQ9D5U0NCIe6ahjSQOAd9IHqA9wCNnBuXuAzwE3A+OB29JbpqfpB9L8mVHODmGzjfTII+um7sCBAwFYuHAhCxcuZPTotqm/Lue2NYuOtiB+0MG81gNyHRkIXCepO9murGkRcbukZ4CbJX2X7L4S16TlrwGulzSP7NfVF8v5A8w21hlnnNHuPEnMnDlzQ104t60pdDQW00GVdBwRTwDr3X0lIl4g22fbtn0l8PlK1mlWjnvuuaei9zu3rVls8BiEpN7APwEHkG05/B74aUp6s05r5cqV/PjHP+b+++9HEgceeCCnnnoqvXv3rndoZg1hgwUCmAL8FbgyTR8HXI9/EVknN27cOLbccktOO+00AG666SaOP/54brnlljpHZtYYyikQu6QrRlvdI+nxdpc26ySef/75de4HcdBBB7HXXm1vfWLWvMq5DuLRNIwAAJL2Af63uJDMamPUqFHMmjVrzfTs2bPZf//96xiRWWMpZwtiH2CcpIVpeijwrKQnyUYd+EBh0ZkVaPbs2UyZMoWhQ4cC2Wmuu+22G3vuueeaq6vNmlk5BeLQwqMwq4O77rqrw/nDhg2rTSBmDaqcK6lfkrQN2ZWgPUraO7xQzqzR7bjjjrz66qssWrSIVatWrWnf0IVyZs2inNNcLwROIBuMrPXqz3IulDNraOeeey6TJ09mxIgRa3YplXmhnFlTKGcX07FkQ36/XXQwZrU0bdo05s+fz2abbVbvUMwaUjlnMT0F9Cs6ELNa22OPPXjttdfqHYZZwypnC+IislNdn6LkPhARcWRhUZnVwDnnnMOoUaPYY4891rkPxPTp0+sYlVnjKKdAXEc2UuWTwOpiwzGrnfHjx3P22Wez55570q1bORvTZs2lnALxSkRcUXgkZjXWv39/vv71r9c7DLOGVU6BeFjSRWRj2pfuYvJprtapffCDH+Scc87hyCOPXGcXk09zNcuUUyBahzXet6TNp7k2gWET76i4jxcvPqwKkRTj0UcfBVhnuA2f5mq2VjkXylV0X4h6qfTLrZG/2Kw6Kr0vhFlXV84WBJIOA94PrBkoPyL+raigzGrljjvu4Omnn2blyrW3
NznvvPPqGJFZ49jgqRuSfgp8ATgNENl9IHYsOC6zwp166qlMnTqVK6+8kojglltu4aWXXqp3WGYNo5wtiP0i4gOSnoiICyT9APhl0YGZlaOSXYmLf/kb3l66gA984AOcf/75nHHGGRx99NFVjM6scyvn5O+30vObkt4LvAMMLy4ks9pQj2yIjc0335zFixfTs2dPFixYUOeozBpHOVsQt0vqB3wfeITsDKb/LjQqsxroM+JDvPbaa3zrW99i9OjRSOKUU06pd1hmDaOcs5guTC9/Iel2oHdErCg2LLPi9dv/OPr168cxxxzD4YcfzsqVK9l6663rHZZZw2h3F5OkD0l6T8n0OGAacKGkbWsRnFkR/r7kj7z7xqtrpqdMmcKxxx7Lueeey/Lly+sYmVlj6egYxFXA2wCSPgpcDEwBVgCTig/NrBjLf/Nf0D3beL7vvvuYOHEi48aNY+utt2bChAl1js6scXS0i6l7RLT+nPoCMCkifkG2q+mx4kMzK0asfpfufbYEYOrUqUyYMIFjjjmGY445hr333rvO0Zk1jo62ILpLai0gY4HS8QfKusDOrCHFamL1uwDMmDGDgw9eO2pM6a1HzZpdR1/0NwG/k/QK2amuvweQtDPZbiazTqnvbh/j5Rsn0q3PVuyyRR8OPPBAAObNm+eD1GYl2t2CiIh/B84AJgMHRESUvOe0DXUsaYikeyQ9K+lpSd9I7dtKulvS3PS8TWqXpCskzZP0hCQPqWmF2Hq/L7DNQSexxZ5juf/++9fcj3r16tVceeWVG3y/c9uaRYe7iiJiVk7bH8vsexVwRkQ8ImlLsmHD7wZOAGZExMWSJgITgbOBTwEj02Mf4Cfp2azqeg3aFYC+ffuuaXvf+95X7tud29YUCruNVkQsab1nRET8FXgWGAQcRXaXOtLzZ9Lro4ApkZkF9JM0sKj4zDaVc9uaRU0ONksaRnZfidnADhGxBLIPmqTt02KDgEUlb2tJbUva9DUBmAAwdOjQQuO26utqw7A7t60rK/xGvJK2AH4BfDMiXu9o0Zy2WK8hYlJEjImIMQMGDKhWmGYbzbltXV2hBUJST7IP0A0R0ToC7Mutm9fpeWlqbwGGlLx9MLC4yPjMNpVz25pBYQVC2akh1wDPRsRlJbOmA+PT6/HAbSXt49IZH/sCK1o3180aiXPbmkWRxyD2B44Hniy58vrbZEN2TJN0ErCQ7AZEAHcCnwbmAW8CJxYYm1klnNvWFAorEBFxP/n7XiG7Mrvt8gF8rah4zKrFuW3NovCD1GZm1jm5QJiZWS4XCDMzy+UCYWZmuVwgzMwslwuEmZnlcoEwM7NcLhBmZpbLBcLMzHK5QJiZWa6a3A/CaqOr3WvBzOrLBcKsyfmHhbXHu5jMzCyXC4SZmeVygTAzs1wuEGZmlssFwszMcrlAmJlZLhcIMzPL5QJhZma5XCDMzCyXC4SZmeVygTAzs1wuEGZmlssFwszMcrlAmJlZLhcIMzPLVViBkHStpKWSnipp21bS3ZLmpudtUrskXSFpnqQnJI0uKi6zSjm3rVkUecOgycB/AlNK2iYCMyLiYkkT0/TZwKeAkemxD/CT9Nyl+UYtndZknNvWBArbgoiI+4DlbZqPAq5Lr68DPlPSPiUys4B+kgYWFZtZJZzb1ixqfQxih4hYApCet0/tg4BFJcu1pDazzsK5bV1OoxykVk5b5C4oTZA0R9KcZcuWFRyWWcWc29Zp1bpAvNy6eZ2el6b2FmBIyXKDgcV5HUTEpIgYExFjBgwYUGiwZhvBuW1dTq0LxHRgfHo9HritpH1cOuNjX2BF6+a6WSfh3LYup7CzmCTdBHwc6C+pBTgfuBiYJukkYCHw+bT4ncCngXnAm8CJRcVlVinntjWLwgpERBzXzqyxOcsG8LWiYjGrJue2NYsir4MwM2tIvgapPC4QZp2Mv9ysVhrlNFczM2swLhBmZpbLBcLMzHK5QJiZWS4XCDMzy+UCYWZmuVwg
zMwslwuEmZnlcoEwM7NcvpLazKxClV7dDutf4d4IV8x7C8LMzHK5QJiZWS4XCDMzy+UCYWZmuVwgzMwslwuEmZnl8mmuZlZVRZzyafXhLQgzM8vlLQgza3iNcNFYM/IWhJmZ5XKBMDOzXC4QZmaWywXCzMxyuUCYmVkuFwgzM8vlAmFmZrkaqkBIOlTS85LmSZpY73jMqsW5bZ1RwxQISd2B/wI+BewOHCdp9/pGZVY557Z1Vg1TIIAPA/Mi4oWIeBu4GTiqzjGZVYNz2zolRUS9YwBA0ueAQyPi5DR9PLBPRPxzm+UmABPS5C7A85u4yv7AK5v43lr0V0SfzRhjJf3tGBEDKg3AuV14f0X02ej9VdpnWbndSGMxKadtveoVEZOASRWvTJoTEWMq7aeo/orosxljLOJv3pQwctqaNredh43bZ1uNtIupBRhSMj0YWFynWMyqybltnVIjFYiHgJGShkvaDPgiML3OMZlVg3PbOqWG2cUUEask/TPwG6A7cG1EPF3gKivelC+4vyL6bMYYi/ibN4pzu/D+iuiz0fsrqs91NMxBajMzayyNtIvJzMwaiAuEmZnlasoCUc1hDyRdK2mppKeqFNsQSfdIelbS05K+UYU+e0t6UNLjqc8LqhRrd0mPSrq9Cn29KOlJSY9JmlOl+PpJulXSc+nf8yPV6LdRVXs4j0bP7c6Q16m/quZ2TfM6IprqQXaQcD6wE7AZ8DiwewX9fRQYDTxVpfgGAqPT6y2BP1YSX+pHwBbpdU9gNrBvFWI9HbgRuL0Kfb0I9K/y//V1wMnp9WZAv1rnW60e1c7r1GdD53ZnyOvUX1Vzu5Z53YxbEFUd9iAi7gOWVyu4iFgSEY+k138FngUGVdhnRMQbabJnelR0doKkwcBhwNWV9FMUSVuRfcFdAxARb0fEa/WNqlBVH86j0XPbeV18XjdjgRgELCqZbqHCL+CiSBoGjCL7ZVRpX90lPQYsBe6OiEr7/CFwFrC60tiSAP5H0sNpyIlK7QQsA36WdhdcLalvFfptVJ0mr6F6ud0J8hqqm9s1zetmLBBlDXtQb5K2AH4BfDMiXq+0v4h4NyL2JruK98OS9qggtsOBpRHxcKVxldg/IkaTjXj6NUkfrbC/HmRo7w8zAAADyElEQVS7R34SEaOAvwFdeZjtTpHXUN3c7gR5DdXN7ZrmdTMWiIYf9kBST7IP0A0R8ctq9p02R+8FDq2gm/2BIyW9SLYr42BJP68wrsXpeSnwK7JdJpVoAVpKflHeSvbB6qoaPq+huNxu1LxOsVUzt2ua181YIBp62ANJItu/+GxEXFalPgdI6pde9wEOAZ7b1P4i4pyIGBwRw8j+/WZGxJcriK+vpC1bXwOfBCo6cyYi/gwskrRLahoLPFNJnw2uofMaqp/bjZ7XKa6q5nat87phhtqolajysAeSbgI+DvSX1AKcHxHXVBDi/sDxwJNp3yrAtyPizgr6HAhcp+zGNd2AaRFRlVP4qmQH4FfZ9wc9gBsj4q4q9HsacEP6wnwBOLEKfTakauc1dIrcbvS8hmJyu2Z57aE2zMwsVzPuYjIzszK4QJiZWS4XCDMzy+UCYWZmuVwgzMwslwtEHUl6Y8NLrVn2XyWdWY3+N2a9afmPV2tkS2sOzu2uwQXCzMxyuUA0GElHSJqdBuL6raQdSmbvJWmmpLmSTil5z7ckPSTpiY0ZEz/9erq3ZGz5G9LVrq33FnhO0v3A0SXv6avsPgEPpRiPSu2nS7o2vd5T0lOSNq/038O6Dud2J1TUOOJ+lDWu+xs5bduw9gLGk4EfpNf/SjbGfx+gP9nIne8lu3R/Etlgbd2A24GPttd/aTvZVbIryMbt6QY8ABwA9E79j0z9TiONjQ98D/hyet2PbEz/vun99wGfBeaQDVBW939jP+rzcG53jUfTDbXRCQwGpkoaSHYzkAUl826LiLeAtyTdQzbo1wFkH6RH0zJbkCX/fWWu78GIaAFIwx8MA94A
FkTE3NT+c6B1mOJPkg1o1rrPuDcwNCKelXQC8ARwVUT870b91dYMnNudjAtE47kSuCwipkv6ONmvq1Ztx0UJsl9BF0XEVZu4vr+XvH6XtTnR3hgsAo6JiOdz5o0k+wC+dxNjsa7Nud3J+BhE49ka+FN6Pb7NvKOU3Yd3O7JN6IfIBmf7irIx9pE0SNL2FcbwHDBc0og0fVzJvN8Ap5Xszx2VnrcGfkR2t6vtJH2uwhis63FudzLegqivzdMoma0uI/tVdYukPwGzgOEl8x8E7gCGAhdGNs78Ykm7AQ+kvH4D+DLZHbY2SUSsVHbnqzskvQLcD7TeiOVCsrtuPZE+SC8ChwOXAz+OiD9KOgm4R9J9kY2Bb83Hud0FeDRXMzPL5V1MZmaWywXCzMxyuUCYmVkuFwgzM8vlAmFmZrlcIMzMLJcLhJmZ5fr/f+tMSfg/OOsAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def plot_label_hist(ax, labels, num_labels, title):\n",
    "    '''\n",
    "    Draw a bar chart of how many samples carry each class index.\n",
    "        ax : matplotlib Axes to draw on\n",
    "        labels : 1-D array of integer class indices\n",
    "        num_labels : (int) number of classes, used for the x ticks\n",
    "        title : (str) panel title\n",
    "    '''\n",
    "    inds, nums = np.unique(labels, return_counts=True)\n",
    "    ax.bar(x=inds, height=nums, width=0.8, bottom=0, align='center')\n",
    "    ax.set_xticks(np.arange(num_labels))\n",
    "    ax.set_xlabel(\"Label Index\")\n",
    "    ax.set_ylabel(\"Sample Num\")\n",
    "    ax.set_ylim((0, 600))   # same y range on both panels so they compare directly\n",
    "    ax.set_title(title)\n",
    "\n",
    "num_labels = len(label_dict)\n",
    "print(\"{}Train Data anf Test Data Info{}\".format(\"*\"*20, \"*\"*20))\n",
    "# One (index, name) line per class, ordered by class index\n",
    "print(\"==> Label list : \"+(\"\\n    {}\"*num_labels).format(*sorted((idx, name) for name, idx in label_dict.items())))\n",
    "\n",
    "# .bool() keeps the indexing valid for uint8 masks as well; .cpu() keeps the cell\n",
    "# usable after data has been moved to another device.\n",
    "train_labels = data.y[data.train_mask.bool()].cpu().numpy()\n",
    "test_labels = data.y[data.test_mask.bool()].cpu().numpy()\n",
    "\n",
    "fig, (ax_train, ax_test) = plt.subplots(1, 2)\n",
    "plot_label_hist(ax_train, train_labels, num_labels, \"Train Data Statics\")\n",
    "plot_label_hist(ax_test, test_labels, num_labels, \"Test Data Statics\")\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PyG构建GCN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Graph-convolution layer built on the MessagePassing base class, following the\n",
    "example in the official tutorial.\n",
    "Construction:\n",
    "    in_channels : (int) size of each input node feature vector\n",
    "    out_channels : (int) size of each output node feature vector\n",
    "\n",
    "forward:\n",
    "    x : (Tensor) node feature matrix, shape (N, in_channels)\n",
    "    edge_index : (LongTensor) edge list, shape (2, E)\n",
    "    returns : (Tensor) new node features, shape (N, out_channels)\n",
    "\n",
    "NOTE: for the rest of the notebook this class replaces the GCNConv that the\n",
    "import cell brings in from torch_geometric.nn.\n",
    "'''\n",
    "class GCNConv(MessagePassing):\n",
    "    def __init__(self, in_channels, out_channels):\n",
    "        # Messages arriving at a node are summed.\n",
    "        super().__init__(aggr='add')\n",
    "        self.lin = nn.Linear(in_channels, out_channels)\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        num_nodes = x.size(0)\n",
    "\n",
    "        # Step 1: give every node an edge to itself.\n",
    "        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)\n",
    "\n",
    "        # Step 2: linear projection, (N, in_channels) -> (N, out_channels).\n",
    "        projected = self.lin(x)\n",
    "\n",
    "        # Steps 3-5: message, aggregation and update are driven by propagate().\n",
    "        return self.propagate(edge_index, size=(num_nodes, num_nodes), x=projected)\n",
    "\n",
    "    def message(self, x_j, edge_index, size):\n",
    "        # x_j has shape [E, out_channels]; edge_index has shape [2, E]\n",
    "\n",
    "        # Step 3: scale each message by 1 / sqrt(deg(row) * deg(col)).\n",
    "        src, dst = edge_index                                  # [E,], [E,]\n",
    "        node_deg = degree(src, size[0], dtype=x_j.dtype)       # [N, ]\n",
    "        inv_sqrt_deg = node_deg.pow(-0.5)                      # [N, ]\n",
    "        edge_weight = inv_sqrt_deg[src] * inv_sqrt_deg[dst]    # [E, ]\n",
    "\n",
    "        return edge_weight.view(-1, 1) * x_j\n",
    "\n",
    "    def update(self, aggr_out):\n",
    "        # Step 5: the aggregated messages, shape [N, out_channels], are returned\n",
    "        # unchanged as the new node embeddings.\n",
    "        return aggr_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Two-layer GCN for node classification.\n",
    "    layer 1 : (N, feat_dim) -> (N, hidden_dim)\n",
    "    layer 2 : (N, hidden_dim) -> (N, num_class)\n",
    "ReLU and dropout are applied between the two layers. The per-class scores are\n",
    "normalised with log_softmax, so the model returns log-probabilities (the input\n",
    "that F.nll_loss expects), not probabilities.\n",
    "\n",
    "Arguments:\n",
    "    feat_dim : (int) input node feature dimension\n",
    "    num_class : (int) number of output classes\n",
    "    hidden_dim : (int) width of the hidden layer, default 16\n",
    "    dropout : (float) dropout probability after the first layer, default 0.5\n",
    "'''\n",
    "class Net(torch.nn.Module):\n",
    "    def __init__(self, feat_dim, num_class, hidden_dim=16, dropout=0.5):\n",
    "        super(Net, self).__init__()\n",
    "        self.dropout = dropout\n",
    "        self.conv1 = GCNConv(feat_dim, hidden_dim)\n",
    "        self.conv2 = GCNConv(hidden_dim, num_class)\n",
    "\n",
    "    def forward(self, data):\n",
    "        x, edge_index = data.x, data.edge_index\n",
    "\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = F.relu(x)\n",
    "        # Dropout is only active in training mode (model.train()).\n",
    "        x = F.dropout(x, p=self.dropout, training=self.training)\n",
    "        x = self.conv2(x, edge_index)\n",
    "\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "'''\n",
    "Train the two-layer GCN on the training nodes and report the test accuracy\n",
    "every EVAL_EVERY epochs.\n",
    "'''\n",
    "SEED = 42          # fixes weight initialisation and dropout noise between runs\n",
    "NUM_EPOCHS = 200\n",
    "EVAL_EVERY = 10    # evaluate on the test nodes every this many epochs\n",
    "\n",
    "torch.manual_seed(SEED)\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# The output width comes from the dataset instead of a hard-coded 7.\n",
    "model = Net(feat_dim=data.num_node_features, num_class=int(data.num_classes)).to(device)  # Initialize model\n",
    "data = data.to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) # Initialize optimizer and training params\n",
    "\n",
    "# Boolean masks: indexing / masked_select with uint8 masks is deprecated in PyTorch.\n",
    "train_mask = data.train_mask.bool()\n",
    "test_mask = data.test_mask.bool()\n",
    "num_train = train_mask.sum().item()\n",
    "num_test = test_mask.sum().item()\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    # Get output\n",
    "    out = model(data)\n",
    "    \n",
    "    # Get loss\n",
    "    loss = F.nll_loss(out[train_mask], data.y[train_mask])\n",
    "    \n",
    "    # Training accuracy, taken from the same forward pass (dropout is active)\n",
    "    pred = out.argmax(dim=1)\n",
    "    correct = pred[train_mask].eq(data.y[train_mask]).sum().item()\n",
    "    acc = correct / num_train\n",
    "    # Epochs are reported 1-based so the last line reads NUM_EPOCHS/NUM_EPOCHS\n",
    "    print('[Epoch {}/{}] Loss {:.4f}, train acc {:.4f}'.format(epoch + 1, NUM_EPOCHS, loss.item(), acc))\n",
    "    \n",
    "    # Backward\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    \n",
    "    # Evaluation on test data every EVAL_EVERY epochs\n",
    "    if (epoch+1) % EVAL_EVERY == 0:\n",
    "        model.eval()\n",
    "        # No gradients are needed for evaluation\n",
    "        with torch.no_grad():\n",
    "            pred = model(data).argmax(dim=1)\n",
    "        correct = pred[test_mask].eq(data.y[test_mask]).sum().item()\n",
    "        acc = correct / num_test\n",
    "        print('Accuracy: {:.4f}'.format(acc))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
